import os
import pickle
import random
import sys

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

from Augmentations.Build_Augmentation import build_augmentation
from Datasets.Build_Dataloader import datasets
from Utils.Logger import print_log
from Utils.Misc import fps

# Absolute directory of this module; it is appended to sys.path so modules
# living next to this file can be imported by their bare name.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(BASE_DIR)


def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: array of shape (N, D); each row is one point.

    Returns:
        A new array of the same shape. Its column-wise mean is zero and its
        farthest row lies at Euclidean distance 1 from the origin. A
        degenerate cloud (every row identical, so the radius is 0) is returned
        centered but unscaled, instead of as NaNs from a 0/0 division.
    """
    centered = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if radius > 0:
        centered = centered / radius
    return centered


def farthest_point_sample(point, npoint):
    """Down-sample a point cloud with iterative farthest point sampling (FPS).

    Distances are measured on the first three columns only; any extra
    columns are carried along with the selected rows. The seed point is
    drawn with ``np.random.randint``, so results follow NumPy's global RNG.

    Args:
        point: array of shape (N, D), D >= 3.
        npoint: number of points to keep.

    Returns:
        Array of shape (npoint, D) holding the selected rows of ``point``
        (the rows themselves, not their indices). If ``npoint`` exceeds N,
        rows necessarily repeat.
    """
    n_points = point.shape[0]
    xyz = point[:, :3]
    selected = np.zeros(npoint, dtype=np.int64)
    # Squared distance from every point to its nearest already-selected point.
    min_dist = np.full(n_points, 1e10)
    farthest = np.random.randint(0, n_points)
    for i in range(npoint):
        selected[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        closer = dist < min_dist
        min_dist[closer] = dist[closer]
        # Next pick: the point farthest from everything selected so far.
        farthest = np.argmax(min_dist, axis=-1)
    return point[selected]


@datasets.register_module()
class ScanObjectNN(Dataset):
    """ScanObjectNN classification dataset backed by the plain ``*_objectdataset.h5`` files.

    The test split is down-sampled once to ``num_points`` points per cloud with
    ``Utils.Misc.fps`` (run on CUDA). The result is cached as a pickle in
    ``data_path``, so later runs and evaluations reuse the same point set.
    The training split is kept at the resolution stored in the HDF5 file.

    Args:
        mode: ``'train'`` or ``'test'``.
        data_path: directory containing the ScanObjectNN ``.h5`` files.
        num_points: number of points per cloud for the cached test split.
        transform: augmentation config forwarded to ``build_augmentation``.
        loop: number of times the dataset is repeated; ``__len__`` is the
            real size multiplied by this.

    Raises:
        NotImplementedError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    # HDF5 file name for each supported split.
    _H5_FILES = {
        'train': 'training_objectdataset.h5',
        'test': 'test_objectdataset.h5',
    }

    def __init__(self, mode, data_path, num_points, transform, loop=1):
        super().__init__()
        self.subset = mode
        self.num_points = num_points
        self.data_path = data_path
        self.loop = loop
        self.transform = build_augmentation(transform)
        # Coordinate axis used to derive the extra "height" feature in __getitem__.
        self.gravity_dim = 1

        if self.subset not in self._H5_FILES:
            raise NotImplementedError(f"Unsupported ScanObjectNN mode: {mode!r}")
        self.points, self.labels = self._read_h5(
            os.path.join(self.data_path, self._H5_FILES[self.subset]))

        if self.subset == 'test':
            precomputed_path = os.path.join(
                data_path, f'{self.subset}_objectdataset_{self.num_points}_fps.pkl')
            self.points = self._load_or_compute_fps(self.points, self.num_points, precomputed_path)

        print_log('The size of %s data is {%d} x {%d}' % (self.subset, self.points.shape[0], self.loop), logger='ScanObjectNN')

    @staticmethod
    def _read_h5(path):
        """Read ``data`` (as float32) and ``label`` (as int) arrays from an HDF5 file."""
        # The context manager closes the file even if a key is missing.
        with h5py.File(path, 'r') as h5:
            points = np.array(h5['data']).astype(np.float32)
            labels = np.array(h5['label']).astype(int)
        return points, labels

    @staticmethod
    def _load_or_compute_fps(points, num_points, cache_path):
        """Return FPS-downsampled ``points``, loading/saving the pickle at ``cache_path``."""
        if os.path.exists(cache_path):
            # NOTE(review): pickle.load can execute arbitrary code embedded in
            # the file; only point data_path at trusted directories.
            with open(cache_path, 'rb') as f:
                sampled = pickle.load(f)
            print(f"{cache_path} load successfully")
            return sampled

        gpu_points = torch.from_numpy(points).to(torch.float32).cuda()
        # Element [1] of fps()'s return value is the sampled point tensor used here.
        sampled = fps(gpu_points, num_points)[1].cpu().numpy()
        # Write to a temporary file and rename it into place, so an interrupted
        # run cannot leave a truncated cache that every later load would choke on.
        tmp_path = cache_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(sampled, f)
        os.replace(tmp_path, cache_path)
        print(f"{cache_path} saved successfully")
        return sampled

    def __getitem__(self, index):
        """Return the transformed sample dict for ``index``.

        ``index`` is wrapped modulo the real dataset size so that ``loop`` > 1
        revisits samples. After the transform, a per-point height is appended
        to ``features``. It is taken from ``heights`` if the transform produced
        one, otherwise it is computed as the coordinate along ``gravity_dim``
        minus its minimum.
        """
        index = index % self.points.shape[0]
        data_dict = {'name': f"{self.subset}: No.{index}"}

        point_set, label = self.points[index], self.labels[index]
        data_dict['coord'] = point_set[:, 0:3]
        data_dict['category'] = np.array(label, dtype=np.int64)

        point_set = self.transform(data_dict)

        if 'heights' in point_set:
            point_set['features'] = torch.cat((point_set['features'], point_set['heights']), dim=1)
        else:
            # The slice (not an integer index) keeps a trailing axis of size 1.
            height = point_set['coord'][:, self.gravity_dim:self.gravity_dim + 1]
            point_set['features'] = torch.cat([point_set['features'], height - height.min()], dim=-1)

        return point_set

    def __len__(self):
        """Real number of clouds times ``loop``."""
        return self.points.shape[0] * self.loop


# @datasets.register_module()
# class ScanObjectNN(Dataset):
#     def __init__(self, mode, data_path, num_points, transform, loop=1):
#         super().__init__()
#         self.subset = mode
#         self.num_points = num_points
#         self.data_path = data_path
#         self.uniform = True
#         self.loop = loop
#         self.transform = build_augmentation(transform)
#
#         if self.subset == 'train':
#             h5 = h5py.File(os.path.join(self.data_path, 'training_objectdataset.h5'), 'r')
#             self.points = np.array(h5['data']).astype(np.float32)
#             self.labels = np.array(h5['label']).astype(int)
#             h5.close()
#         elif self.subset == 'test':
#             h5 = h5py.File(os.path.join(self.data_path, 'test_objectdataset.h5'), 'r')
#             self.points = np.array(h5['data']).astype(np.float32)
#             self.labels = np.array(h5['label']).astype(int)
#             h5.close()
#         else:
#             raise NotImplementedError()
#
#         print_log('The size of %s data is {%d} x {%d}' % (self.subset, self.points.shape[0], self.loop), logger='ScanObjectNN')
#
#     def __getitem__(self, index):
#         index = index % self.points.shape[0]
#         data_dict = {'name': f"{self.subset}: No.{index}"}
#
#         point_set, label = self.points[index], self.labels[index]
#
#         # point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
#
#         if self.uniform:
#             point_set = farthest_point_sample(point_set, self.num_points)
#         else:
#             point_set = point_set[0:self.num_points, :]
#
#         data_dict['coord'] = point_set[:, 0:3]
#         data_dict['category'] = np.array(label, dtype=np.int64)
#
#         point_set = self.transform(data_dict)
#
#         return point_set
#
#     def __len__(self):
#         return self.points.shape[0] * self.loop


@datasets.register_module()
class ScanObjectNN_hardest(Dataset):
    """ScanObjectNN classification dataset backed by the ``*_augmentedrot_scale75.h5`` files.

    The test split is down-sampled once to ``num_points`` points per cloud with
    ``Utils.Misc.fps`` (run on CUDA). The result is cached as a pickle in
    ``data_path``, so later runs and evaluations reuse the same point set.
    The training split is kept at the resolution stored in the HDF5 file.

    Args:
        mode: ``'train'`` or ``'test'``.
        data_path: directory containing the ScanObjectNN ``.h5`` files.
        num_points: number of points per cloud for the cached test split.
        transform: augmentation config forwarded to ``build_augmentation``.
        loop: number of times the dataset is repeated; ``__len__`` is the
            real size multiplied by this.

    Raises:
        NotImplementedError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    # HDF5 file name for each supported split.
    _H5_FILES = {
        'train': 'training_objectdataset_augmentedrot_scale75.h5',
        'test': 'test_objectdataset_augmentedrot_scale75.h5',
    }

    def __init__(self, mode, data_path, num_points, transform, loop=1):
        super().__init__()
        self.subset = mode
        self.num_points = num_points
        self.data_path = data_path
        # Not read anywhere in this class; kept for external code that may inspect it.
        self.uniform = True
        self.loop = loop
        self.transform = build_augmentation(transform)
        # Coordinate axis used to derive the extra "height" feature in __getitem__.
        self.gravity_dim = 1

        if self.subset not in self._H5_FILES:
            raise NotImplementedError(f"Unsupported ScanObjectNN_hardest mode: {mode!r}")
        self.points, self.labels = self._read_h5(
            os.path.join(self.data_path, self._H5_FILES[self.subset]))

        if self.subset == 'test':
            precomputed_path = os.path.join(
                data_path, f'{self.subset}_objectdataset_augmentedrot_scale75_{self.num_points}_fps.pkl')
            self.points = self._load_or_compute_fps(self.points, self.num_points, precomputed_path)

        print_log('The size of %s data is {%d} x {%d}' % (self.subset, self.points.shape[0], self.loop), logger='ScanObjectNN')

    @staticmethod
    def _read_h5(path):
        """Read ``data`` (as float32) and ``label`` (as int) arrays from an HDF5 file."""
        # The context manager closes the file even if a key is missing.
        with h5py.File(path, 'r') as h5:
            points = np.array(h5['data']).astype(np.float32)
            labels = np.array(h5['label']).astype(int)
        return points, labels

    @staticmethod
    def _load_or_compute_fps(points, num_points, cache_path):
        """Return FPS-downsampled ``points``, loading/saving the pickle at ``cache_path``."""
        if os.path.exists(cache_path):
            # NOTE(review): pickle.load can execute arbitrary code embedded in
            # the file; only point data_path at trusted directories.
            with open(cache_path, 'rb') as f:
                sampled = pickle.load(f)
            print(f"{cache_path} load successfully")
            return sampled

        gpu_points = torch.from_numpy(points).to(torch.float32).cuda()
        # Element [1] of fps()'s return value is the sampled point tensor used here.
        sampled = fps(gpu_points, num_points)[1].cpu().numpy()
        # Write to a temporary file and rename it into place, so an interrupted
        # run cannot leave a truncated cache that every later load would choke on.
        tmp_path = cache_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(sampled, f)
        os.replace(tmp_path, cache_path)
        print(f"{cache_path} saved successfully")
        return sampled

    def __getitem__(self, index):
        """Return the transformed sample dict for ``index``.

        ``index`` is wrapped modulo the real dataset size so that ``loop`` > 1
        revisits samples. After the transform, a per-point height is appended
        to ``features``. It is taken from ``heights`` if the transform produced
        one, otherwise it is computed as the coordinate along ``gravity_dim``
        minus its minimum.
        """
        index = index % self.points.shape[0]
        data_dict = {'name': f"{self.subset}: No.{index}"}

        point_set, label = self.points[index], self.labels[index]
        data_dict['coord'] = point_set[:, 0:3]
        data_dict['category'] = np.array(label, dtype=np.int64)

        point_set = self.transform(data_dict)

        if 'heights' in point_set:
            point_set['features'] = torch.cat((point_set['features'], point_set['heights']), dim=1)
        else:
            # The slice (not an integer index) keeps a trailing axis of size 1.
            height = point_set['coord'][:, self.gravity_dim:self.gravity_dim + 1]
            point_set['features'] = torch.cat((point_set['features'], height - height.min()), dim=-1)

        return point_set

    def __len__(self):
        """Real number of clouds times ``loop``."""
        return self.points.shape[0] * self.loop


# @datasets.register_module()
# class ScanObjectNN_hardest(Dataset):
#     def __init__(self, mode, data_path, num_points, transform, loop=1):
#         super().__init__()
#         self.subset = mode
#         self.num_points = num_points
#         self.data_path = data_path
#         self.uniform = True
#         self.loop = loop
#         self.transform = build_augmentation(transform)
#
#         if self.subset == 'train':
#             h5 = h5py.File(os.path.join(self.data_path, 'training_objectdataset_augmentedrot_scale75.h5'), 'r')
#             self.points = np.array(h5['data']).astype(np.float32)
#             self.labels = np.array(h5['label']).astype(int)
#             h5.close()
#         elif self.subset == 'test':
#             h5 = h5py.File(os.path.join(self.data_path, 'test_objectdataset_augmentedrot_scale75.h5'), 'r')
#             self.points = np.array(h5['data']).astype(np.float32)
#             self.labels = np.array(h5['label']).astype(int)
#             h5.close()
#         else:
#             raise NotImplementedError()
#
#         print_log('The size of %s data is {%d} x {%d}' % (self.subset, self.points.shape[0], self.loop), logger='ScanObjectNN')
#
#     def __getitem__(self, index):
#         index = index % self.points.shape[0]
#         data_dict = {'name': f"{self.subset}: No.{index}"}
#
#         point_set, label = self.points[index], self.labels[index]
#
#         point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
#
#         if self.uniform:
#             point_set = farthest_point_sample(point_set, self.num_points)
#         else:
#             point_set = point_set[0:self.num_points, :]
#
#         data_dict['coord'] = point_set[:, 0:3]
#         data_dict['category'] = np.array(label, dtype=np.int64)
#
#         point_set = self.transform(data_dict)
#
#         return point_set
#
#     def __len__(self):
#         return self.points.shape[0] * self.loop

def resplite(data_path='/home/klinqu/桌面/PointLearner/Data/scanobjectnn/main_split',
             out_path='/home/klinqu/桌面/PointLearner/Data/ScanObjectNN/main_split',
             split=(83, 199, 133, 372, 390, 150, 204, 210, 241, 270, 110, 105, 120, 210, 85),
             seed=None):
    """Re-draw the train/test split of the ``*_augmentedrot_scale75.h5`` files.

    Pools the training and test clouds read from ``data_path``. For each class
    ``i`` it then sends ``split[i]`` randomly chosen samples to the new test set
    and the remaining samples of that class to the new training set. Samples
    whose label is ``>= len(split)`` are dropped. Both new sets are written to
    ``out_path`` under the original file names, overwriting existing files.

    NOTE(review): the default ``data_path`` and ``out_path`` differ only in
    case (``scanobjectnn`` vs ``ScanObjectNN``). On a case-sensitive
    filesystem they are different directories. TODO confirm which is intended.

    Args:
        data_path: directory to read the two input HDF5 files from.
        out_path: directory to write the two re-split HDF5 files to.
        split: number of test samples to draw for each class index.
        seed: if given, shuffling uses a private ``random.Random(seed)`` so the
            split is reproducible. ``None`` uses the module-level ``random``
            state, as before.

    Raises:
        ValueError: if a class has fewer samples than its requested test size.
    """
    train_name = 'training_objectdataset_augmentedrot_scale75.h5'
    test_name = 'test_objectdataset_augmentedrot_scale75.h5'

    def read(name):
        # Context manager closes the file even if a dataset key is missing.
        with h5py.File(os.path.join(data_path, name), 'r') as h5:
            return np.array(h5['data']).astype(np.float32), np.array(h5['label']).astype(int)

    train_points, train_labels = read(train_name)
    test_points, test_labels = read(test_name)
    points = np.concatenate([train_points, test_points], axis=0)
    labels = np.concatenate([train_labels, test_labels], axis=0)
    # Exactly one label per sample is required; reshape fails loudly otherwise.
    flat_labels = labels.reshape(len(labels))

    rng = random if seed is None else random.Random(seed)
    train_index = []
    test_index = []
    for cls, num_test in enumerate(split):
        cls_index = np.flatnonzero(flat_labels == cls).tolist()
        if len(cls_index) < num_test:
            raise ValueError(
                f"class {cls} has {len(cls_index)} samples, fewer than the {num_test} requested for test")
        rng.shuffle(cls_index)
        test_index.extend(cls_index[:num_test])
        train_index.extend(cls_index[num_test:])

    for name, index in ((train_name, train_index), (test_name, test_index)):
        with h5py.File(os.path.join(out_path, name), 'w') as hf:
            hf.create_dataset('data', data=points[index])
            hf.create_dataset('label', data=labels[index])


if __name__ == "__main__":
    # Running this module directly regenerates the ScanObjectNN split files.
    resplite()
